#include <float.h>
#include "./Kalman_Filter_Velocity.hpp"

using Kalman_Filter_Velocity_Namespace::Sys_State_Union;

static union Sys_State_Union sys_state_data, sys_state_forecast_data, measure_data;         /* 用于储存状态量，预测状态量和测量值 */
static arm_matrix_instance_f32 sys_state_matrix, sys_state_forecast_matrix, measure_matrix; /* 状态量矩阵，预测状态量矩阵和测量矩阵 */

static float gyro_pre_z_rate = 0.0f; /* 上一个控制周期陀螺仪的角速度(rad/s) */

static float covariance_data[9], covariance_forecast_data[9];                 /* 用于存储协方差和预测协方差 */
static arm_matrix_instance_f32 covariance_matrix, covariance_forecast_matrix; /* 状态量矩阵和预测状态量矩阵 */

static float R_data[16], W_data[12], Q_data[9];              /* 存储噪声 */
static arm_matrix_instance_f32 R_matrix, W_matrix, Q_matrix; /* 控制噪声和测量噪声 */

static float K_data[9], I_data[9];                 /* 用于存储卡尔曼增益 */
static arm_matrix_instance_f32 K_matrix, I_matrix; /* 卡尔曼增益矩阵和单位矩阵 */

void Kalman_Filter_Velocity_Namespace::Init(const float v_noise, const float z_rate_noise, const float wheel_radius, const float distance_x, const float distance_y)
{
    int i = 0;
    arm_mat_init_f32(&sys_state_matrix, 3, 1, sys_state_data.data);
    arm_mat_init_f32(&sys_state_forecast_matrix, 3, 1, sys_state_forecast_data.data);
    arm_mat_init_f32(&measure_matrix, 3, 1, measure_data.data);
    /* 状态量初始值为0 */
    for (i = 0; i < 3; i++)
    {
        sys_state_data.data[i] = 0.0f;
        sys_state_forecast_data.data[i] = 0.0f;
        measure_data.data[i] = 0.0f;
    }

    arm_mat_init_f32(&covariance_matrix, 3, 3, covariance_data);
    arm_mat_init_f32(&covariance_forecast_matrix, 3, 3, covariance_forecast_data);
    /* 协方差初始值为0 */
    for (i = 0; i < 9; i++)
    {
        covariance_data[i] = 0.0f;
        covariance_forecast_data[i] = 0.0f;
    }

    arm_mat_init_f32(&R_matrix, 4, 4, R_data);
    /* 初始化控制噪声矩阵 */
    for (i = 0; i < 16; i++)
    {
        R_data[i] = 0.0f;
    }
    /* 电机执行噪声 */
    R_data[0] = R_data[5] = R_data[10] = R_data[15] = v_noise * wheel_radius * wheel_radius;

    arm_mat_init_f32(&W_matrix, 3, 4, W_data);
    float distance_temp = 1.0f / ((distance_x + distance_y) / 2.0f);

    W_data[0] = -1.0f / 4.0f;
    W_data[1] = 1.0f / 4.0f;
    W_data[2] = -1.0f / 4.0f;
    W_data[3] = 1.0f / 4.0f;
    W_data[4] = W_data[5] = W_data[6] = W_data[7] = 1.0f / 4.0f;
    W_data[8] = distance_temp / 4.0f;
    W_data[9] = -distance_temp / 4.0f;
    W_data[10] = -distance_temp / 4.0f;
    W_data[11] = distance_temp / 4.0f;

    arm_mat_init_f32(&Q_matrix, 3, 3, Q_data);
    /* 初始化测量噪声矩阵 */
    for (i = 0; i < 9; i++)
    {
        Q_data[i] = 0.0f;
    }
    /* 陀螺仪无法测量线速度，因此方差为正无穷 */
    Q_data[0] = FLT_MAX;
    Q_data[4] = FLT_MAX;
    Q_data[8] = z_rate_noise / 2.0f; /* 取前后两次平均值 */

    arm_mat_init_f32(&K_matrix, 3, 3, K_data);
    arm_mat_init_f32(&I_matrix, 3, 3, I_data);
    /* 协方差初始值为0 */
    for (i = 0; i < 9; i++)
    {
        K_data[i] = 0.0f;
        I_data[i] = 0.0f;
    }
    I_data[0] = I_data[4] = I_data[8] = 1.0f;
}

void Kalman_Filter_Velocity_Namespace::Kalman_Filter_Velocity(const union Input_Control_Union &input, /* 输入控制量 */
                                                              const float gyro_z_rate,                /* 陀螺仪的角速度测量量(rad/s) */
                                                              union Sys_State_Union &sys_state,       /* 计算返回得到的状态量 */
                                                              arm_matrix_instance_f32 &covariance     /* 计算返回得到的协方差矩阵 */
)
{
    arm_matrix_instance_f32 input_matrix;
    arm_mat_init_f32(&input_matrix, 4, 1, (float *)input.data);

    arm_mat_mult_f32(&W_matrix, &input_matrix, &sys_state_forecast_matrix); /* 计算状态预测量 */

    /* 矩阵A为零矩阵，因此如下直接计算得到预测协方差covariance_forecast_matrix */
    Kalman_Filter_Function::Cal_A_mul_B_A_Trans(W_matrix, R_matrix, covariance_forecast_matrix);

    Kalman_Filter_Function::Cal_Kalman_Gain(covariance_forecast_matrix, Q_matrix, K_matrix);                              /* 计算得到卡尔曼增益矩阵 */
    measure_data.w = (gyro_z_rate + gyro_pre_z_rate) / 2.0f;                                                              /* 更新测量值 */
    gyro_pre_z_rate = gyro_z_rate;                                                                                        /* 保存上一个控制周期的角速度 */
    Kalman_Filter_Function::Update_State_Variable(sys_state_forecast_matrix, K_matrix, measure_matrix, sys_state_matrix); /* 更新状态值 */
    Kalman_Filter_Function::Update_Covariance(covariance_forecast_matrix, I_matrix, K_matrix, covariance_matrix);

    sys_state = sys_state_data;
    covariance = covariance_matrix;
}
